#! /usr/bin/python3

'''
Generate test modules with all interesting casts
'''

import argparse
import itertools

# Each entry is (source heap type, cast target heap type, test-name fragment).
# '$super', '$sub' and '$sub-final' are the struct types declared in the
# generated module header; 'none' is the bottom heap type.
interesting_pairs = [
    ('$super', '$super', 'cast-to-self-nonfinal'),
    ('$sub-final', '$sub-final', 'cast-to-self-final'),
    ('$sub', '$super', 'cast-to-super'),
    ('$super', '$sub', 'cast-to-sub'),
    ('$sub-final', '$sub', 'cast-to-sibling'),
    ('$super', 'none', 'cast-to-bottom'),
    ('none', '$super', 'cast-from-bottom'),
]


def gen_test_configs(args):
    '''Yield one configuration tuple per cast test to emit.

    Each tuple is (heap_name, src_heap, cast_heap, src_nullable,
    cast_nullable, src_exact, cast_exact). Every entry of interesting_pairs
    is combined with every True/False assignment of the four flags. Two
    kinds of combinations are dropped: those that would make the heap type
    'none' exact, and those whose cast exactness differs from
    args.enable_descs.
    '''
    for src_heap, cast_heap, heap_name in interesting_pairs:
        for flags in itertools.product((True, False), repeat=4):
            src_nullable, src_exact, cast_nullable, cast_exact = flags
            # 'none' is never spelled as an exact heap type in these tests.
            exact_bottom = ((src_exact and src_heap == 'none') or
                            (cast_exact and cast_heap == 'none'))
            # With --enable-custom-descs only exact casts are generated;
            # without it only inexact casts are.
            wrong_mode = args.enable_descs != cast_exact
            if exact_bottom or wrong_mode:
                continue
            yield (heap_name, src_heap, cast_heap,
                   src_nullable, cast_nullable, src_exact, cast_exact)


def print_test(config):
    '''Print the WAT text of one cast test function.

    config is a 7-tuple (heap_name, src_heap, cast_heap, src_nullable,
    cast_nullable, src_exact, cast_exact). The emitted function takes a
    parameter of the source reference type and returns the result of
    ref.cast to the target type. The parameter is first routed through an
    anyref local via local.tee.
    '''
    (heap_name, src_heap, cast_heap,
     src_nullable, cast_nullable, src_exact, cast_exact) = config

    def describe(nullable, exact):
        # Name fragment such as 'non-null-exact' used in the function name.
        nullability = 'null' if nullable else 'non-null'
        exactness = 'exact' if exact else 'inexact'
        return f'{nullability}-{exactness}'

    def ref_type(heap, nullable, exact):
        # Spell a reference type in WAT, e.g. '(ref null (exact $sub))'.
        heap_part = f'(exact {heap})' if exact else heap
        null_part = ' null' if nullable else ''
        return f'(ref{null_part} {heap_part})'

    test_name = (f'{heap_name}-{describe(src_nullable, src_exact)}-to-'
                 f'{describe(cast_nullable, cast_exact)}')
    src_type = ref_type(src_heap, src_nullable, src_exact)
    cast_type = ref_type(cast_heap, cast_nullable, cast_exact)

    # A leading empty line separates consecutive functions in the output.
    lines = [
        '',
        f'  (func ${test_name} (param {src_type}) (result {cast_type})',
        '    (local anyref)',
        f'    (ref.cast {cast_type}',
        '      (local.tee 1',
        '        (local.get 0)',
        '      )',
        '    )',
        '  )',
    ]
    print('\n'.join(lines))


def print_tests(args):
    '''Print every generated cast test function, in generation order.'''
    configs = gen_test_configs(args)
    for one_config in configs:
        print_test(one_config)


def print_header(args):
    '''Print the lit-test preamble: notes, RUN line(s) and the module types.

    When args.enable_descs is set, a single RUN line with all features is
    emitted. Otherwise a second RUN line also runs with custom descriptors
    disabled and checks under the NO_CD prefix. The '(module' form is left
    open here; print_footer prints the closing parenthesis.
    '''
    flags = ' --enable-custom-descs' if args.enable_descs else ''
    notes = (
        ';; NOTE: Assertions have been generated by update_lit_checks.py '
        'and should not be edited.\n'
        f';; NOTE: Test has been generated by scripts/test/gen-cast-test.py{flags}. '
        'Do not edit manually.\n'
        '\n'
        ';; Exhaustively test optimization of all interesting casts.\n'
    )

    if args.enable_descs:
        run_lines = [
            ';; RUN: wasm-opt %s -all --optimize-instructions -S -o - | filecheck %s',
        ]
    else:
        # The first line is padded so both RUN commands line up.
        run_lines = [
            ';; RUN: wasm-opt %s -all                              --optimize-instructions -S -o - | filecheck %s',
            ';; RUN: wasm-opt %s -all --disable-custom-descriptors --optimize-instructions -S -o - | filecheck %s --check-prefix=NO_CD',
        ]

    module_start = [
        '(module',
        '  (type $super (sub (struct)))',
        '  (type $sub (sub $super (struct)))',
        '  (type $sub-final (sub final $super (struct)))',
    ]

    run_block = ''.join(f'{line}\n' for line in run_lines)
    print(notes + '\n' + run_block + '\n' + '\n'.join(module_start))


def print_footer():
    '''Print the parenthesis that closes the '(module' opened by the header.'''
    closing_paren = ')'
    print(closing_paren)


def main(argv=None):
    '''Parse command-line options and print the whole generated test module.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. With the default None, argparse reads sys.argv[1:], so
            running the script behaves exactly as before.
    '''
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--enable-custom-descs', action='store_true', dest='enable_descs',
        help='generate only casts to exact types, with a single RUN line '
             'using all features (by default only casts to inexact types are '
             'generated, and the test is also run with custom descriptors '
             'disabled)')
    args = parser.parse_args(argv)
    print_header(args)
    print_tests(args)
    print_footer()


if __name__ == "__main__":
    # Generate the test module only when executed as a script, not on import.
    main()
